Add support for SWA (left, right) with FusedAttention #2477
Conversation
…IA#1369 Signed-off-by: Sudhakar Singh <[email protected]>
for more information, see https://pre-commit.ci
/te-ci pytorch L0
Greptile Summary
Important Files Changed
Confidence score: 3/5
Additional Comments (2)
- transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py, line 1281 (link) - logic: Trailing comma creates a single-element tuple instead of a boolean; should this be just `bottom_right_alignment = attn_mask_type not in ["causal", "padding_causal"]`? (See the sketch after this list.)
- transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py, line 1482 (link) - style: Uses a hardcoded mask type check instead of the new `bottom_right_diagonal` parameter for ALiBi alignment. Should this use the `bottom_right_diagonal` parameter for consistency? Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
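For context, here is a minimal standalone illustration of the trailing-comma pitfall flagged in the first comment above; the variable names mirror the comment, and this is illustrative code rather than the repository's:

    attn_mask_type = "causal"

    # Trailing comma: this builds a one-element tuple such as (False,), and any
    # non-empty tuple is truthy, so a later boolean check silently passes.
    bottom_right_alignment = attn_mask_type not in ["causal", "padding_causal"],
    print(type(bottom_right_alignment), bool(bottom_right_alignment))  # <class 'tuple'> True

    # Without the comma the value is the intended boolean.
    bottom_right_alignment = attn_mask_type not in ["causal", "padding_causal"]
    print(type(bottom_right_alignment), bool(bottom_right_alignment))  # <class 'bool'> False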
15 files reviewed, 8 comments
transformer_engine/pytorch/attention/dot_product_attention/backends.py (outdated; resolved)
transformer_engine/pytorch/attention/dot_product_attention/utils.py (outdated; resolved)
if self_attn_mask_type in {"causal", "padding_causal"}:
    bottom_right_diagonal = False
if bottom_right_diagonal is None or self_attn_mask_type in {
    "causal_bottom_right",
    "padding_causal_bottom_right",
}:
    bottom_right_diagonal = True
logic: Logic overrides the instance variable even when explicitly set in forward call - should preserve user's explicit choice. Should the mask type check override an explicitly passed bottom_right_diagonal parameter, or only apply when it's None?
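A minimal sketch of the None-guarded alternative this comment is asking about: derive the default from the mask type only when the caller did not pass a value. This is an assumed restructuring for illustration, not the code in the PR:

    def resolve_bottom_right_diagonal(bottom_right_diagonal, attn_mask_type):
        # Respect an explicitly passed value; derive a default only from None.
        if bottom_right_diagonal is None:
            bottom_right_diagonal = attn_mask_type not in ("causal", "padding_causal")
        return bottom_right_diagonal

    assert resolve_bottom_right_diagonal(None, "causal") is False
    assert resolve_bottom_right_diagonal(None, "causal_bottom_right") is True
    assert resolve_bottom_right_diagonal(False, "causal_bottom_right") is False  # explicit value kept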
if enc_dec_attn_mask_type in {"causal", "padding_causal"}:
    enc_dec_bottom_right_diagonal = False
if enc_dec_bottom_right_diagonal is None or enc_dec_attn_mask_type in {
    "causal_bottom_right",
    "padding_causal_bottom_right",
}:
    enc_dec_bottom_right_diagonal = True
logic: Same logic issue as above - mask type check overrides explicit parameter values
Could you add a couple of SWA tests to the CP tests as well? I think it's just a matter of replacing (left, 0) with (left, right) and testing them out. Thanks!
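A rough sketch of how the extra (left, right) cases could slot into a parametrized test; the window sizes and the small mask helper below are hypothetical stand-ins, not the actual CP test code:

    import pytest
    import torch

    def swa_allowed(seqlen_q, seqlen_kv, window):
        # Hypothetical helper: key positions each query may attend, with the
        # diagonal aligned to the bottom-right corner of the softmax matrix.
        i = torch.arange(seqlen_q).unsqueeze(1)
        j = torch.arange(seqlen_kv).unsqueeze(0)
        shift = seqlen_kv - seqlen_q
        return (j >= i + shift - window[0]) & (j <= i + shift + window[1])

    # Reuse the existing (left, 0) cases and add (left, right) ones next to them.
    @pytest.mark.parametrize("window", [(64, 0), (64, 16), (128, 32)])
    def test_right_window_only_adds_positions(window):
        left_only = swa_allowed(128, 256, (window[0], 0))
        both = swa_allowed(128, 256, window)
        # A nonzero right window may only allow more key positions, never fewer.
        assert bool((left_only & ~both).sum() == 0)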
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu (outdated; resolved)
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu (outdated; resolved)
transformer_engine/pytorch/attention/dot_product_attention/backends.py (outdated; resolved)
transformer_engine/pytorch/attention/dot_product_attention/utils.py (outdated; resolved)
transformer_engine/pytorch/attention/dot_product_attention/utils.py (outdated; resolved)
Signed-off-by: Sudhakar Singh <[email protected]>
15 files reviewed, no comments
Additional Comments (4)
- transformer_engine/pytorch/transformer.py, line 777-783 (link) - logic: Conditional logic overwrites an explicitly passed parameter. When `self_attn_mask_type in {"causal", "padding_causal"}` at line 777, this unconditionally sets `bottom_right_diagonal = False`, even if the user explicitly passed a different value in the forward call. The check at line 779 (`if bottom_right_diagonal is None`) becomes unreachable in those cases. Consider restructuring to only apply defaults when `bottom_right_diagonal is None`.
- transformer_engine/pytorch/transformer.py, line 787-793 (link) - logic: Same conditional logic issue for encoder-decoder attention. The mask type check at line 787 unconditionally overwrites `enc_dec_bottom_right_diagonal`, ignoring explicitly passed values.
- transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 920 (link) - logic: Variable name mismatch: sets `use_flash_attention` instead of `use_flash_attention_2`.
- transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 947 (link) - logic: Variable name mismatch: sets `use_flash_attention` instead of `use_flash_attention_2`.
15 files reviewed, 4 comments
/te-ci pytorch L0
1 similar comment
/te-ci pytorch L0
Additional Comments (1)
- transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py, line 1283 (link) - syntax: Tuple has a trailing comma, causing it to be a single-element tuple instead of a boolean.
15 files reviewed, 1 comment
/te-ci pytorch L0
15 files reviewed, no comments
Signed-off-by: Sudhakar Singh <[email protected]>
/te-ci pytorch L0
Signed-off-by: Sudhakar Singh <[email protected]>
a56c4d2 to a545ebf (Compare)
Signed-off-by: Sudhakar Singh <[email protected]>
…ansformerEngine into swa_padding_brcm_try2
/te-ci pytorch L1
for more information, see https://pre-commit.ci
Additional Comments (6)
- transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 911 (link) - logic: Variable name mismatch: sets `use_flash_attention` instead of `use_flash_attention_2`.
- transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 938 (link) - logic: Variable name mismatch: sets `use_flash_attention` instead of `use_flash_attention_2`.
- transformer_engine/pytorch/transformer.py, line 777-783 (link) - logic: Logic overrides the user's explicit parameter. If the user passes `bottom_right_diagonal=True` and the mask type is `causal`, line 778 forces it to `False`, then line 783 forces it back to `True` if the mask type contains "bottom_right". Defaults should only apply when the parameter is `None`.
- transformer_engine/pytorch/transformer.py, line 787-793 (link) - logic: Same logic issue: overrides the user's explicit `enc_dec_bottom_right_diagonal` parameter.
- transformer_engine/pytorch/attention/multi_head_attention.py, line 751-757 (link) - logic: Same logic issue as transformer.py: overrides the user's explicit `bottom_right_diagonal` parameter.
- transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py, line 1099-1105 (link) - logic: Same logic issue: overrides the user's explicit `bottom_right_diagonal` parameter.
17 files reviewed, 6 comments
Signed-off-by: Sudhakar Singh <[email protected]>
for more information, see https://pre-commit.ci
Additional Comments (6)
- transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 911 (link) - syntax: Variable name mismatch - should be `use_flash_attention_2`, not `use_flash_attention`.
- transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 938 (link) - syntax: Same variable name mismatch - should be `use_flash_attention_2`, not `use_flash_attention`.
- transformer_engine/pytorch/transformer.py, line 777-778 (link) - logic: Logic unconditionally overrides the user-provided `bottom_right_diagonal` parameter. If the user explicitly passes `bottom_right_diagonal=True` with `self_attn_mask_type="causal"`, it gets overridden to `False`. Should the mask type check only apply when `bottom_right_diagonal` is `None`, rather than override explicit parameter values?
- transformer_engine/pytorch/transformer.py, line 787-788 (link) - logic: Same issue - logic overrides explicit `enc_dec_bottom_right_diagonal` parameter values based on the mask type. Should this only apply when the parameter is `None`?
- transformer_engine/pytorch/attention/multi_head_attention.py, line 751-752 (link) - logic: Same logic issue as in transformer.py - the mask type check overrides explicit `bottom_right_diagonal` values. The mask type check should only apply when the parameter is `None`.
- transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py, line 1098-1099 (link) - logic: Consistent with the other files - the mask type check overrides explicit `bottom_right_diagonal` values instead of only applying when `None`. Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
17 files reviewed, 6 comments
/te-ci pytorch L1
Signed-off-by: Sudhakar Singh <[email protected]>
…ansformerEngine into swa_padding_brcm_try2
/te-ci pytorch
/te-ci pytorch L1
5 files reviewed, 5 comments
| "Disabling FlashAttention as it only supports sliding window with bottom right" | ||
| " diagonal alignment for cross-attention" | ||
| ) | ||
| use_flash_attention = False |
Critical: Variable name mismatch causing incorrect backend selection
This line sets use_flash_attention instead of use_flash_attention_2, which was the intended target based on the conditional check on line 906. This causes FlashAttention to be incorrectly disabled when it should only disable FlashAttention 2.
Suggested change: `use_flash_attention = False` -> `use_flash_attention_2 = False`
Impact: This bug causes FlashAttention to be incorrectly disabled for sliding window attention with cross-attention scenarios where bottom_right_diagonal=False and max_seqlen_q != max_seqlen_kv. The code should only disable FlashAttention 2 in this case, not all FlashAttention backends.
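A small sketch of the guard this comment implies, keeping the two flags separate so that only the FlashAttention 2 path is dropped; this is a hypothetical shape of the selection logic, not the file's actual control flow:

    def filter_flash_backends(use_flash_attention, use_flash_attention_2,
                              bottom_right_diagonal, max_seqlen_q, max_seqlen_kv):
        # Per the comment above, only the FA2 flag should be cleared when a sliding
        # window with top-left alignment meets cross-attention sequence lengths.
        if (use_flash_attention_2
                and not bottom_right_diagonal
                and max_seqlen_q != max_seqlen_kv):
            use_flash_attention_2 = False  # not: use_flash_attention = False
        return use_flash_attention, use_flash_attention_2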
| "Disabling FlashAttention as it only supports ALiBi with bottom right diagonal" | ||
| " alignment for cross-attention" | ||
| ) | ||
| use_flash_attention = False |
Critical: Variable name mismatch causing incorrect backend selection
Similar to line 911, this line sets use_flash_attention instead of use_flash_attention_2. The conditional check on line 933 verifies not bottom_right_diagonal and max_seqlen_q != max_seqlen_kv for FlashAttention 2 with ALiBi, but then incorrectly disables all FlashAttention instead of just FlashAttention 2.
Suggested change: `use_flash_attention = False` -> `use_flash_attention_2 = False`
Impact: This causes FlashAttention to be incorrectly disabled for ALiBi attention with cross-attention where bottom_right_diagonal=False and max_seqlen_q != max_seqlen_kv. Only FlashAttention 2 should be disabled in this scenario.
if attn_mask_type in {"causal", "padding_causal"}:
    bottom_right_diagonal = False
if bottom_right_diagonal is None or attn_mask_type in {
    "causal_bottom_right",
    "padding_causal_bottom_right",
}:
    bottom_right_diagonal = True
Logic issue: Explicit bottom_right_diagonal parameter is unconditionally overridden
When a user explicitly passes bottom_right_diagonal=False and attn_mask_type is "causal_bottom_right", the code first sets it to False on line 1100, but then immediately overrides it to True on line 1105. This means the explicit parameter value is ignored.
Problem scenario:

    # User explicitly requests bottom_right_diagonal=False
    forward(..., attn_mask_type="causal_bottom_right", bottom_right_diagonal=False)
    # Line 1100: bottom_right_diagonal = False (respects explicit value)
    # Line 1105: bottom_right_diagonal = True (overrides explicit value!)

The condition on line 1101 checks bottom_right_diagonal is None, which will be False after line 1100, but the or clause still forces it to True.
Suggested fix: The logic should respect explicit values and only set defaults when bottom_right_diagonal is None:

    if attn_mask_type in {"causal", "padding_causal"}:
        if bottom_right_diagonal is None:
            bottom_right_diagonal = False
    elif bottom_right_diagonal is None:
        # Only set to True if not explicitly specified
        if attn_mask_type in {"causal_bottom_right", "padding_causal_bottom_right"}:
            bottom_right_diagonal = True
        else:
            bottom_right_diagonal = True  # default for other mask types

This same pattern appears in multiple locations and should be fixed consistently throughout the codebase.
Could we add some explanation in the init or fwd docstring that bottom_right_diagonal will be forced to True if the mask type has bottom_right in it. Just so users (and Greptile) don't get confused. It helps remove this comment from Greptile as well.
You're absolutely right! The current behavior where bottom_right_diagonal gets overridden based on the mask type should be clearly documented. This would help users understand when their explicit parameter values will be respected versus when they'll be overridden.
Looking at the current logic, I can see that:
- Line 1100: If the mask type is "causal" or "padding_causal", it's forced to `False`
- Lines 1101-1105: If it's `None` OR the mask type contains "bottom_right", it's forced to `True`
Here's a suggestion to add clarity to the docstring:
Current logic:

    if attn_mask_type in {"causal", "padding_causal"}:
        bottom_right_diagonal = False
    if bottom_right_diagonal is None or attn_mask_type in {
        "causal_bottom_right",
        "padding_causal_bottom_right",
    }:
        bottom_right_diagonal = True

Suggested docstring:

    bottom_right_diagonal: Optional[bool], default = `None`
        Align sliding window and ALiBi diagonal to the top left (`False`)
        or bottom right (`True`) corner of the softmax matrix in the encoder.
        If `None`, it will be set to `False` for `attn_mask_type` =
        {'causal', 'padding_causal'} and `True` for other mask types.
        Note: This parameter will be automatically overridden based on the
        `attn_mask_type` - it will be forced to `False` for 'causal' and
        'padding_causal' mask types, and forced to `True` for mask types
        containing 'bottom_right' (e.g., 'causal_bottom_right',
        'padding_causal_bottom_right'), regardless of the explicitly passed value.
This documentation makes it clear that the parameter can be overridden, helping users understand the actual behavior and avoiding confusion about when explicit values are respected.
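To make the documented behavior concrete, here is a small table of effective values implied by the quoted logic and docstring; this is my reading of the snippet above, not output from the code:

    # (attn_mask_type, value passed by the caller) -> effective bottom_right_diagonal
    expected = {
        ("causal", None): False,
        ("causal", True): False,                       # forced to False
        ("padding_causal", None): False,
        ("causal_bottom_right", None): True,
        ("causal_bottom_right", False): True,          # forced to True
        ("padding_causal_bottom_right", False): True,  # forced to True
        ("no_mask", None): True,                       # default for other mask types
        ("no_mask", False): False,                     # explicit value kept here
    }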
if attn_mask_type in {"causal", "padding_causal"}:
    bottom_right_diagonal = False
if bottom_right_diagonal is None or attn_mask_type in {
    "causal_bottom_right",
    "padding_causal_bottom_right",
}:
    bottom_right_diagonal = True
Logic issue: Explicit bottom_right_diagonal parameter is unconditionally overridden
This has the same logic flaw as in dot_product_attention.py. When a user explicitly passes bottom_right_diagonal=False with attn_mask_type="causal_bottom_right", the explicit value gets overridden.
The code flow:
- Line 752: Sets `bottom_right_diagonal = False` for causal masks
- Lines 753-757: Immediately overrides to `True` if the mask type contains "bottom_right"
This means explicit parameter values are ignored. The logic should only set defaults when bottom_right_diagonal is None, not override explicit values.
if self_attn_mask_type in {"causal", "padding_causal"}:
    bottom_right_diagonal = False
if bottom_right_diagonal is None or self_attn_mask_type in {
    "causal_bottom_right",
    "padding_causal_bottom_right",
}:
    bottom_right_diagonal = True

...

if enc_dec_bottom_right_diagonal is None:
    enc_dec_bottom_right_diagonal = self.enc_dec_bottom_right_diagonal
if enc_dec_attn_mask_type in {"causal", "padding_causal"}:
    enc_dec_bottom_right_diagonal = False
if enc_dec_bottom_right_diagonal is None or enc_dec_attn_mask_type in {
    "causal_bottom_right",
    "padding_causal_bottom_right",
}:
    enc_dec_bottom_right_diagonal = True
Logic issue: Explicit bottom_right_diagonal parameters are unconditionally overridden
This code has the same logic flaw as seen in other files. Both bottom_right_diagonal (lines 777-783) and enc_dec_bottom_right_diagonal (lines 787-793) suffer from the same issue where explicit parameter values get overridden.
When a user explicitly passes these parameters with specific mask types, the explicit values are ignored because:
- Lines 778, 788: Set to `False` for causal masks
- Lines 779-783, 789-793: Immediately override to `True` if the mask type contains "bottom_right"
The logic should respect explicit values and only set defaults when the parameters are None.
/te-ci pytorch L1
/te-ci pytorch L1
Greptile's behavior is changing! From now on, if a review finishes with no comments, we will not post an additional "statistics" comment to confirm that our review found nothing to comment on. However, you can confirm that we reviewed your changes in the status check section. This feature can be toggled off in your Code Review Settings by deselecting "Create a status check for each PR".
/te-ci pytorch L1
_attention_backends["backend_selection_requires_update"] = True
available_backends, flash_attention_backend, fused_attention_backend = test()
if fused_attention_backend == FusedAttnBackend[backends[i]]:
    fused_attn_backends.append(fused_attention_backend)
Is this change intended, or is it for debugging? Would this fix the printing issue we talked about before, where the unit test wouldn't print the selection logic from this for loop?
Apparently, its use here is a bit problematic, so removing it helps print the information. (I vaguely remember that you suggested I remove it; I then checked the logging documentation and you were right.)
with logging_context(highest_level=AttentionLogging._log_level):
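For reference, a minimal sketch of why a logging_context of this shape can swallow the selection messages: logging.disable() suppresses every record at or below the given level process-wide until it is reset. The helper below is a stand-in for illustration, not TE's implementation:

    import logging
    from contextlib import contextmanager

    @contextmanager
    def logging_context(highest_level=logging.CRITICAL):
        # Suppress all records with level <= highest_level for the whole process.
        previous = logging.root.manager.disable
        logging.disable(highest_level)
        try:
            yield
        finally:
            logging.disable(previous)

    logging.basicConfig(level=logging.DEBUG)
    log = logging.getLogger("backend_selection")

    with logging_context(highest_level=logging.DEBUG):
        log.debug("selected backend: FusedAttention")  # swallowed inside the context
    log.debug("selected backend: FusedAttention")      # emitted once the context exits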
transformer_engine/common/include/transformer_engine/fused_attn.h (outdated; resolved)
    fu_core_attention_bias_shape = "1hss"
elif len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads:

...

if len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads:
Why was the `alibi_slopes_shape is None` branch removed here?
The encapsulating if condition makes sure that this branch is always false
if (
    ...
    and (alibi_slopes_shape is not None)
):
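A compact illustration of that point: with the enclosing guard already requiring a non-None shape, an inner `alibi_slopes_shape is None` branch can never be taken. The function and return values below are hypothetical placeholders mirroring the quoted structure:

    def pick_bias_shape(alibi_slopes_shape, num_heads):
        if alibi_slopes_shape is not None:  # the enclosing guard quoted above
            # Inside this block "alibi_slopes_shape is None" is always False, so a
            # branch testing it is dead code; only the checks below can matter.
            if len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads:
                return "shape_a"  # placeholder
            return "shape_b"      # placeholder
        return None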
transformer_engine/pytorch/attention/dot_product_attention/backends.py (outdated; resolved)
Signed-off-by: Sudhakar Singh <[email protected]>
/te-ci pytorch L1
17 files reviewed, 3 comments
    in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
    + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
    window and causal mask specifically.
bottom_right_diagonal: bool, default = True
style: Documentation inconsistency: the forward function documents the default as None, but the backward documents it as True.
Suggested change: `bottom_right_diagonal: bool, default = True` -> `bottom_right_diagonal: bool, default = None`
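As a side note, the inclusive range from the quoted docstring is easy to sanity-check with a reference mask; this sketch (including the treatment of the -1 special cases) is illustrative only and is not the fused kernel's code:

    import torch

    def sliding_window_allowed(seqlen_q, seqlen_k, window_size):
        # True where query i may attend key j, i.e. j lies in
        # [i + seqlen_k - seqlen_q - window_size[0],
        #  i + seqlen_k - seqlen_q + window_size[1]], inclusive.
        # -1 means unbounded on that side; a width >= seqlen_q + seqlen_k never clips.
        left = seqlen_q + seqlen_k if window_size[0] == -1 else window_size[0]
        right = seqlen_q + seqlen_k if window_size[1] == -1 else window_size[1]
        i = torch.arange(seqlen_q).unsqueeze(1)
        j = torch.arange(seqlen_k).unsqueeze(0)
        return (j >= i + seqlen_k - seqlen_q - left) & (j <= i + seqlen_k - seqlen_q + right)

    causal = sliding_window_allowed(4, 4, (-1, 0))   # reduces to a plain causal mask
    full = sliding_window_allowed(4, 6, (-1, -1))    # no sliding window: all True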
0,
0,
true,
true,
logic: Critical: Hardcoded parameter instead of proper API parameter
This hardcoded true value should be replaced with a proper bottom_right_diagonal parameter passed from the function signature. The current implementation bypasses the intended parameter passing mechanism.
Issue: The function fused_attn_fp8_fwd_impl_v1 should accept bottom_right_diagonal as a parameter and pass it here, rather than hardcoding true. This breaks the configurability that the PR is supposed to add. Should this hardcoded true be replaced with a bottom_right_diagonal parameter that gets passed through the function signature?
0,
0,
true,
false,
logic: Critical: Hardcoded parameter instead of proper API parameter
Similar to the forward pass, this hardcoded false value should be replaced with a proper bottom_right_diagonal parameter passed from the function signature.
Issue: The function fused_attn_fp8_bwd_impl_v1 should accept bottom_right_diagonal as a parameter and pass it here. The hardcoded false suggests different behavior for backward pass, but this should still be configurable. Should this hardcoded false be replaced with a bottom_right_diagonal parameter, and why is it different from the forward pass value?
Description
FusedAttention has supported "right"-side sliding window attention for some time now. This PR adds support for SWA (left, right) with the FusedAttention backend in TE.
(changes cherry-picked from original PR: #1369)
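At the usage level, a hedged sketch of what this enables; the exact DotProductAttention signature and the availability of the fused backend depend on the TE version, GPU, and cuDNN, so treat the details below as assumptions rather than a guaranteed recipe:

    import torch
    from transformer_engine.pytorch import DotProductAttention

    attn = DotProductAttention(num_attention_heads=16, kv_channels=64)

    # q/k/v in the default "sbhd" layout: (seq, batch, heads, head_dim)
    q = torch.randn(2048, 2, 16, 64, dtype=torch.bfloat16, device="cuda")
    k = torch.randn(2048, 2, 16, 64, dtype=torch.bfloat16, device="cuda")
    v = torch.randn(2048, 2, 16, 64, dtype=torch.bfloat16, device="cuda")

    # (left, right) sliding window: each query attends up to 128 keys to its left
    # and 16 to its right; previously only (left, 0) reached the fused backend.
    out = attn(q, k, v, attn_mask_type="no_mask", window_size=(128, 16))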
Type of change
Changes
Please list the changes introduced in this PR:
- transformer_engine
  - common
    - fused_attn
      - fused_attn.cpp: add `bottom_right_diagonal` parameter to the API
      - fused_attn_f16_arbitrary_seqlen.cu: add `bottom_right_diagonal` parameter to the API
      - fused_attn_fp8.cu: add `bottom_right_diagonal` parameter to the `FADescriptor_v1` API
      - utils.h: add `bottom_right_diagonal` parameter to the `FADescriptor_v1` API
  - pytorch
    - transformer.py: pass `bottom_right_diagonal` through the call stack: `TransformerLayer` --> `SelfAttention`/`CrossAttention`
    - attention
      - dot_product_attention
        - backends.py:
          - `UnfusedDotProductAttention`: add `bottom_right_diagonal` parameter to the `forward` API
            - `forward`? `bottom_right_alignment` is being used in the ALiBi call; perhaps this should be corrected
          - `FusedAttn` custom module: add `bottom_right_diagonal` parameter to the `forward` API
          - `FusedAttention` module: pass `bottom_right_diagonal` through the call stack
        - dot_product_attention.py
          - `DotProductAttention`: pass `bottom_right_diagonal` through the call stack; set `bottom_right_diagonal` if it's `None`
        - utils.py
          - `AttentionParams`, `get_attention_backend`
      - multi_head_attention.py: add `bottom_right_diagonal` to the forward API and call; set `bottom_right_diagonal` if it's `None`
    - cpp_extensions
      - fused_attn.py: `bottom_right_diagonal` in `fused_attn_fwd`/`fused_attn_bwd`
    - csrc
      - extension
        - attention.cpp: pass `bottom_right_diagonal` through the call stack: `fused_attn_fwd` --> `nvte_fused_attn_fwd`
      - extensions.h: add `bottom_right_diagonal` to the `fused_attn_fwd` and `fused_attn_bwd` API definitions

Checklist: